PyTorch দিয়ে GAN তৈরি

Generative Adversarial Networks (GANs) - পাইটর্চ (Pytorch) - Machine Learning

369

Generative Adversarial Network (GAN) একটি জনপ্রিয় ডিপ লার্নিং মডেল যা দুটি নিউরাল নেটওয়ার্ক (জেনারেটর এবং ডিসক্রিমিনেটর) একে অপরের বিরুদ্ধে প্রতিদ্বন্দ্বিতা করে। এটি Generative Modeling এর জন্য ব্যবহৃত হয় এবং এটি নতুন, বাস্তবসম্মত ডেটা তৈরি করতে সক্ষম।

GAN এর মূল উপাদান:

  1. জেনারেটর (Generator): জেনারেটর একটি নতুন ডেটা তৈরি করে, যেমন একটি নতুন চিত্র। এটি র্যান্ডম নইস (noise) ইনপুট নেয় এবং সেটি একটি বাস্তব ডেটা প্যাটার্নে রূপান্তর করে।
  2. ডিসক্রিমিনেটর (Discriminator): ডিসক্রিমিনেটর নতুন এবং আসল ডেটার মধ্যে পার্থক্য বুঝতে চেষ্টা করে। এটি বাস্তব ডেটাকে 1 এবং জেনারেটরের তৈরি ডেটাকে 0 হিসাবে শ্রেণীবদ্ধ করে।

GAN এর উদ্দেশ্য হলো জেনারেটর এমনভাবে প্রশিক্ষিত হবে যাতে সে ডিসক্রিমিনেটরকে বিভ্রান্ত করতে পারে এবং বাস্তব ডেটার মতো ডেটা তৈরি করতে পারে।


PyTorch দিয়ে GAN তৈরি করার প্রক্রিয়া:

এখানে আমরা একটি বেসিক GAN মডেল তৈরি করব যা MNIST (হস্তলিখিত ডিজিট) ডেটাসেট ব্যবহার করবে।

১. লাইব্রেরি ইনস্টলেশন

PyTorch এবং অন্যান্য প্রয়োজনীয় লাইব্রেরি ইনস্টল করতে:

pip install torch torchvision matplotlib

২. জেনারেটর এবং ডিসক্রিমিনেটর তৈরি করা

প্রথমে Generator এবং Discriminator মডেল তৈরি করতে হবে। GAN এর মূল লক্ষ্য হলো জেনারেটরকে প্রশিক্ষিত করা যাতে সে ডিসক্রিমিনেটরকে বিভ্রান্ত করে।

import torch
import torch.nn as nn

# Generator মডেল
class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(z_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 1024)
        self.fc5 = nn.Linear(1024, 28 * 28)  # MNIST চিত্রের আকার
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):
        x = self.relu(self.fc1(z))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.tanh(self.fc5(x))  # Channeled image output
        return x.view(-1, 1, 28, 28)  # reshape to 28x28 image

# Discriminator মডেল
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten input
        x = self.leaky_relu(self.fc1(x))
        x = self.leaky_relu(self.fc2(x))
        x = self.leaky_relu(self.fc3(x))
        x = self.sigmoid(self.fc4(x))  # Probability of real vs fake
        return x
  • Generator: এটি একটি fully connected নেটওয়ার্ক যা র্যান্ডম নইস (random noise) ইনপুট নিয়ে একটি চিত্র তৈরি করে।
  • Discriminator: এটি একটি fully connected নেটওয়ার্ক যা একটি চিত্র নেয় এবং সেটি আসল না জেনারেটেড তা শ্রেণীভুক্ত করে।

৩. ডেটাসেট লোড করা (MNIST)

MNIST ডেটাসেটের ইনপুট চিত্র 28x28 পিক্সেল সাইজের, তাই জেনারেটর এবং ডিসক্রিমিনেটর উভয়কেই এই আকারে ডেটা গ্রহণ ও উৎপন্ন করতে হবে।

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# MNIST ডেটাসেট লোডিং
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

৪. অপটিমাইজার এবং ক্রস এন্ট্রপি লস ফাংশন

GAN মডেলকে প্রশিক্ষণের জন্য আমরা Binary Cross Entropy Loss ব্যবহার করব, যেখানে Discriminator বাস্তব এবং জেনারেটেড চিত্রের মধ্যে পার্থক্য করবে এবং Generator চেষ্টা করবে ডিসক্রিমিনেটরকে বিভ্রান্ত করার জন্য।

# Loss function
criterion = nn.BCELoss()

# Optimizer (Adam)
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

generator = Generator(z_dim=100).cuda()
discriminator = Discriminator().cuda()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

৫. ট্রেনিং লুপ

এখন আমরা ট্রেনিং লুপ তৈরি করব যেখানে Generator এবং Discriminator একে অপরের বিরুদ্ধে প্রতিদ্বন্দ্বিতা করবে। প্রথমে, Discriminator আসল চিত্র এবং জেনারেটেড চিত্রের মধ্যে পার্থক্য করবে, তারপর Generator চেষ্টা করবে তার তৈরি চিত্র ডিসক্রিমিনেটরকে বিভ্রান্ত করতে।

import torch
import numpy as np
import matplotlib.pyplot as plt

# Noise generation function
def generate_noise(batch_size, z_dim):
    return torch.randn(batch_size, z_dim).cuda()

# Train the GAN model
num_epochs = 10
z_dim = 100

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        real_images = real_images.cuda()

        # Create labels
        real_labels = torch.ones(real_images.size(0), 1).cuda()
        fake_labels = torch.zeros(real_images.size(0), 1).cuda()

        # Train Discriminator
        optimizer_D.zero_grad()

        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)

        noise = generate_noise(real_images.size(0), z_dim)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())  # Detach to avoid training generator
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()

        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

        if i % 200 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Generate and save fake images
    with torch.no_grad():
        fake_images = generator(generate_noise(16, z_dim))
        fake_images = fake_images.cpu().data
        fig, axes = plt.subplots(4, 4, figsize=(8, 8))
        for i in range(4):
            for j in range(4):
                axes[i, j].imshow(fake_images[i*4+j].squeeze(), cmap='gray')
                axes[i, j].axis('off')
        plt.show()

৬. ট্রেনিং এর সময় চিত্র দেখানো

ট্রেনিং চলাকালীন সময়ে, Generator দ্বারা তৈরি চিত্র গুলি দেখানোর জন্য matplotlib ব্যবহার করতে পারেন।


সারাংশ

এই উদাহরণে আমরা একটি সাধারণ GAN মডেল তৈরি করেছি যা MNIST ডেটাসেটের জন্য প্রশিক্ষিত। Generator নতুন চিত্র তৈরি করে এবং Discriminator এই চিত্রের প্রকৃততা যাচাই করে। generator এবং discriminator মডেলগুলি Adversarial Training এর মাধ্যমে একে অপরকে শক্তিশালী করে।

  • Generator নতুন ডেটা তৈরি করে।
  • Discriminator বাস্তব এবং জেনারেটেড ডেটার মধ্যে পার্থক্য খুঁজে বের করে।

GAN মডেল ট্রেনিং অনেক সময় নেয়, তবে এর মাধ্যমে অসাধারণ বাস্তবসম্মত ডেটা

তৈরি করা সম্ভব।

Content added By
Promotion

Are you sure to start over?

Loading...